{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 15、Dropout原理以及其TF&Torch&Numpy源码实现\n",
    "\n",
    "- [Dropout 论文](https://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)\n",
    "- [RDropout 论文](https://arxiv.org/pdf/2106.14448)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1622, -0.0000, -1.2924],\n",
       "        [ 1.6719, -0.2125, -0.9680]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 使用 Drouout 类\n",
    "m = nn.Dropout(p=0.4)\n",
    "input = torch.randn(2, 3)\n",
    "m(input) # 40%的节点置为0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0000, -0.0000, -1.5508],\n",
       "        [ 0.0000, -0.2551, -1.1616]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 使用 Dropout 函数\n",
    "dropout_func = nn.functional.dropout\n",
    "dropout_func(input, p=0.5, training=True) # 如果不是training,那么dropout无效"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读论文\n",
    "\n",
    "原来机器学习为了防止过拟合，都是使用集成学习，但是在运行的时候很耗费资源，于是hinton提出了了dropout来解决过拟合。\n",
    "\n",
    "训练的时候，dropout会随机丢弃部分神经元，其实是在训练多个不同的网络。\n",
    "\n",
    "![image-20250619165655657](http://assets.hypervoid.top/img/2025/06/19/image-20250619165655657-afa6.png)\n",
    "\n",
    "应用dropout后，部分神经元会只和部分的神经元相连，\n",
    "\n",
    "\n",
    "图二：测试时不失活，我们怎样在测试时逼近效果？\n",
    "测试阶段，乘以一个权重p,保证输出的**期望**相同。（不过为了测试时越简单越好，那么可以改成在训练阶段乘以系数(1-p)）\n",
    "![image-20250619165836319](http://assets.hypervoid.top/img/2025/06/19/image-20250619165836319-040e.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# 这里的权重缩放在test阶段进行\n",
    "\n",
    "def train(x, w1, b1, w2, b2, dropout_prob=0.5):\n",
    "    ll1:np.ndarray = np.dot(w1, x) + b1\n",
    "    ll1 = np.maximum(0, ll1)\n",
    "    # 应用dropout\n",
    "    mask = np.random.binomial(ll1.size, 1-dropout_prob)\n",
    "    ll1 *= mask\n",
    "    # 第二层\n",
    "    ll2:np.ndarray = np.dot(w2, ll1) + b2\n",
    "    ll2 = np.maximum(0, ll2)\n",
    "    mask = np.random.binomial(ll2.size, 1-dropout_prob)\n",
    "    ll2 *= mask\n",
    "    return ll2\n",
    "\n",
    "def test(x, w1, b1, w2, b2, dropout_prob):\n",
    "    ll1:np.ndarray = np.dot(w1, x) + b1\n",
    "    ll1 = np.maximum(0, ll1)\n",
    "    # 应用dropout\n",
    "    ll1 *= 1 - dropout_prob\n",
    "    # 第二层\n",
    "    ll2:np.ndarray = np.dot(w2, ll1) + b2\n",
    "    ll2 = np.maximum(0, ll2)\n",
    "    ll2 *= 1 - dropout_prob\n",
    "    return ll2\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# 这里的权重缩放在 train 阶段进行\n",
    "\n",
    "def train(x, w1, b1, w2, b2, dropout_prob=0.5):\n",
    "    ll1:np.ndarray = np.dot(w1, x) + b1\n",
    "    ll1 = np.maximum(0, ll1)\n",
    "    # 应用dropout\n",
    "    mask = np.random.binomial(ll1.size, 1-dropout_prob)\n",
    "    ll1 *= mask / (1-dropout_prob)\n",
    "    # 第二层\n",
    "    ll2:np.ndarray = np.dot(w2, ll1) + b2\n",
    "    ll2 = np.maximum(0, ll2)\n",
    "    mask = np.random.binomial(ll2.size, 1-dropout_prob)\n",
    "    ll2 *= mask / (1-dropout_prob)\n",
    "    return ll2\n",
    "\n",
    "def test(x, w1, b1, w2, b2):\n",
    "    ll1:np.ndarray = np.dot(w1, x) + b1\n",
    "    ll1 = np.maximum(0, ll1)\n",
    "    ll2:np.ndarray = np.dot(w2, ll1) + b2\n",
    "    ll2 = np.maximum(0, ll2)\n",
    "    return ll2\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## R-Dropout\n",
    "\n",
    "用于解决dropout中训练和测试不太一致的问题。此外对于不同的minibatch,实际上训练的是n个网络。\n",
    "R-Dropout 在训练过程中，将同一个输入样本**两次**通过带有 Dropout 的模型，得到两个不同的输出概率分布，然后通过最小化这两个分布之间的**双向 KL 散度**，来约束这两个由 Dropout 随机生成的子模型，使其输出保持一致。\n",
    "\n",
    "![image-20250619192916306](http://assets.hypervoid.top/img/2025/06/19/image-20250619192916306-61d7.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# R-Dropout 伪代码\n",
    "\n",
    "def train(x, w1, b1, w2, b2, dropout_prob=0.5):\n",
    "    x = torch.cat([x,x], 0)\n",
    "    ll1:np.ndarray = np.dot(w1, x) + b1\n",
    "    ll1 = np.maximum(0, ll1)\n",
    "    # 应用dropout\n",
    "    mask = np.random.binomial(ll1.size, 1-dropout_prob)\n",
    "    ll1 *= mask / (1-dropout_prob)\n",
    "    # 第二层\n",
    "    ll2:np.ndarray = np.dot(w2, ll1) + b2\n",
    "    ll2 = np.maximum(0, ll2)\n",
    "    mask = np.random.binomial(ll2.size, 1-dropout_prob)\n",
    "    ll2 *= mask / (1-dropout_prob)\n",
    "    logits1, logits2 = some_func(ll2)\n",
    "    nll_loss1 = nll(logits1, logits2)\n",
    "    nll_loss2 = nll(logits2, logits1)\n",
    "    kl_loss = kl(logits1, logits2)\n",
    "    loss = nll_loss1 + nll_loss2 + kl_loss\n",
    "    return loss\n",
    "\n",
    "def test(x, w1, b1, w2, b2):\n",
    "    ll1:np.ndarray = np.dot(w1, x) + b1\n",
    "    ll1 = np.maximum(0, ll1)\n",
    "    ll2:np.ndarray = np.dot(w2, ll1) + b2\n",
    "    ll2 = np.maximum(0, ll2)\n",
    "    return ll2\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
